# -*- coding: utf-8 -*-
# %%
import numpy as np
import scipy.ndimage as ndimage
import scipy.io as sio
import matplotlib.pyplot as plt
import os.path
import matplotlib.patches as patches
import time
import sys
import h5py


# %%
# NOTE: the definition below is wrapped in a module-level string literal, so it
# is never executed and AtomCenter_Pos is not defined at import time.  It only
# preserves a hard-coded list of 18 four-element centre vectors for reference.
'''def AtomCenter_Pos():
    Atom_Center = [np.array([208, 276, 0, 0]), np.array([257, 265, 0, 0]), np.array([307, 254, 0, 0]),
                   np.array([318, 305, 0, 0]), np.array([329, 354, 0, 0]), np.array([280, 366, 0, 0]),
                   np.array([230, 377, 0, 0]), np.array([217, 327, 0, 0]), np.array([269, 316, 0, 0]),
                   np.array([279, 107, 0, 0]), np.array([329, 94, 0, 0]), np.array([381, 83, 0, 0]),
                   np.array([392, 136, 0, 0]), np.array([404, 186, 0, 0]), np.array([354, 196, 0, 0]),
                   np.array([302, 209, 0, 0]), np.array([290, 155, 0, 0]), np.array([342, 147, 0, 0])]
    return Atom_Center'''


# %%
def Im_Rot_Copy(Image, Cut_Width, Rot_Copy_Switch, Cut_Switch, Width_Min):
    """
    Keep the central disc of a square image and optionally back-fill the corners.

    Python port of imRotCopy.m.

    Image: 2-D array of shape (Cut_Width, Cut_Width).
    Cut_Width: integer side length of the square image.
    Rot_Copy_Switch: when 1, the pixels outside the disc that lie in the two
        diagonal quadrants (x <= centre and y <= centre, or x >= centre and
        y >= centre) are kept and a copy of them rotated by 90 degrees is added.
    Cut_Switch: when 1 the mask is applied; for any other value Image is
        returned unchanged.
    Width_Min: margin subtracted from the centre coordinate to obtain the disc
        radius.

    Returns the masked image (or Image itself when Cut_Switch != 1).
    """
    if Cut_Switch != 1:
        return Image

    Cut_X, Cut_Y = np.meshgrid(np.arange(0, Cut_Width), np.arange(0, Cut_Width))
    FBZ_Center_x = np.ceil(Cut_Width / 2)
    FBZ_Center_y = np.ceil(Cut_Width / 2)
    Circle_Radial = FBZ_Center_x - Width_Min
    # Distance of every pixel from the centre.
    Range = np.sqrt((Cut_X - FBZ_Center_x) ** 2 + (Cut_Y - FBZ_Center_y) ** 2)

    # Cut_Map_0 is 1 strictly inside the disc of radius Circle_Radial, 0 elsewhere.
    # The shape must be passed as a tuple: np.zeros(a, b) treats b as a dtype.
    Cut_Map_0 = np.zeros((Cut_Width, Cut_Width))
    Cut_Map_0[Range < Circle_Radial] = 1
    Cut_Map_1 = np.zeros((Cut_Width, Cut_Width))
    # NOTE(review): Cut_Map_2 is never populated, so its terms contribute zero;
    # it is kept only to mirror the structure of the original routine.
    Cut_Map_2 = np.zeros((Cut_Width, Cut_Width))

    if Rot_Copy_Switch == 1:
        # Element-wise '&' is required here; 'and' on arrays raises ValueError.
        Outside = Range > Circle_Radial
        Cut_Map_1[Outside & (Cut_X <= FBZ_Center_x) & (Cut_Y <= FBZ_Center_y)] = 1
        Cut_Map_1[Outside & (Cut_X >= FBZ_Center_x) & (Cut_Y >= FBZ_Center_y)] = 1

    Cut_Map_1 = Cut_Map_1 * Image
    Cut_Map_2 = Cut_Map_2 * Image
    Cut_Map_1_Rot90 = ndimage.rotate(Cut_Map_1, 90)
    Cut_Map_2_Rot90 = ndimage.rotate(Cut_Map_2, 90)
    return Cut_Map_0 * Image + Cut_Map_1 + Cut_Map_1_Rot90 + Cut_Map_2 + Cut_Map_2_Rot90


# %%
def Find_Atom_Contour(Data, Threshold, Figure_Index):
    """
    Zero out every element of Data that does not exceed Threshold.

    Elements strictly greater than Threshold are kept unchanged.  Figure_Index
    is only referenced by the disabled preview plot below.
    """
    Above_Threshold = Data > Threshold
    #    fig = plt.figure(Figure_Index)
    #    plt.subplot(121); plt.imshow(Data, vmin=-0.1, vmax=0.5, cmap='jet')
    #    plt.subplot(122); plt.imshow(Data_Cold, vmin=-0.1, vmax=0.5, cmap='jet')
    #    plt.title('Atom: '+str(Figure_Index))
    #    plt.pause(0.5)
    #    plt.close()
    return Data * Above_Threshold


# %%
def Cal_AtomSize(Od_Image, Threshold):
    """
    Estimate a cloud size from a 2-D image.

    Returns 0 when the sum of Od_Image is below Threshold[0].
    Returns 16 when no element of Od_Image exceeds Threshold[1].
    Otherwise returns the mean of the row extent and the column extent
    (max index minus min index) of the elements above Threshold[1],
    but never less than 16.
    """
    if np.sum(Od_Image) < Threshold[0]:
        return 0

    Rows, Cols = np.where(Od_Image > Threshold[1])
    if np.size(Rows) == 0:
        return 16

    Extent_Row = Rows.max() - Rows.min()
    Extent_Col = Cols.max() - Cols.min()
    Mean_Extent = (Extent_Row + Extent_Col) / 2
    # 16 acts as a lower bound on the reported size.
    return max((Mean_Extent, 16))


def Is_BEC(Od_Image, Atom_Size, Threshold):
    """
    Label a cloud with Fi according to BECT = max(Od_Image) / Atom_Size.

    If BECT <= Threshold[0], Fi = 0.
    If Threshold[0] < BECT <= Threshold[1], Fi = 0.5.
    If BECT > Threshold[1], Fi = 1.
    The label is meant to select different starting parameters for fitting.

    Raises ValueError when BECT cannot be compared with the thresholds (for
    example NaN from a 0 / 0 division); previously this case fell through every
    branch and failed with an UnboundLocalError on return.
    """
    Od_Max = np.max(Od_Image)
    BECT = Od_Max / Atom_Size
    if BECT <= Threshold[0]:
        return 0
    if BECT <= Threshold[1]:
        return 0.5
    if BECT > Threshold[1]:
        return 1
    # Only reachable when every comparison is False, i.e. BECT is NaN.
    raise ValueError('Is_BEC: Od_Max / Atom_Size is not comparable (got %r)' % (BECT,))


# %%
def Atom_N1(Data_Ori, Center_Pos, Width, Background, Figure_Index,
            Matrix_Shift, Width_Min, Rot_Copy_Switch,
            Cut_Switch, Mode, Find_Atom_Contour_Threshold, **kw):
    """
    Cut a window out of Data_Ori around Center_Pos and integrate it.

    Data_Ori: 2-D array indexed as [y, x].
    Center_Pos: sequence whose first two entries are the (x, y) centre; values
        are rounded up with np.ceil before use.
    Width: integer side length of a square window, or a list/tuple
        [width_x, width_y] for a rectangular one.
    Background: per-pixel offset.  In Mode 0 and Mode 100 it is subtracted from
        the window; in Mode 4 window values below it are set to 0.
    Width_Min, Rot_Copy_Switch, Cut_Switch: forwarded to Im_Rot_Copy, which is
        only called when the keyword Circle_Cut=1 is given (that path passes
        Width straight through, so it needs a scalar Width).
    Mode: 0 or 100 (background-subtracted sum) or 4 (thresholded sum).
    Figure_Index, Matrix_Shift, Find_Atom_Contour_Threshold: accepted for
        interface compatibility; not used by the current implementation.

    Returns (OD_Sum, Region) where Region is the processed window that was
    summed.  Window indices are not bounds-checked here.

    Raises ValueError for any other Mode.  Mode 4, Mode 5 and unknown modes
    used to fail with UnboundLocalError at the return statement; Mode 4 now
    returns its thresholded window, and modes without a defined computation
    are rejected explicitly.
    """
    if not (Mode == 0 or Mode == 4 or Mode == 100):
        raise ValueError('Atom_N1: no computation is defined for Mode %r' % (Mode,))

    Center_Pos = np.ceil(Center_Pos)
    # Width_Mult = 3.2
    if isinstance(Width, (list, tuple)):
        Width_x, Width_y = Width[0], Width[1]
    else:
        Width_x = Width_y = Width

    Start_x = int(Center_Pos[0] - np.floor(Width_x / 2)) + 1
    Start_y = int(Center_Pos[1] - np.floor(Width_y / 2)) + 1
    # Row vector of x indices and column vector of y indices broadcast to the
    # full window; fancy indexing returns a copy, so Data_Ori is never modified.
    Vector_x = np.arange(Start_x, Start_x + Width_x).reshape((1, Width_x))
    Vector_y = np.arange(Start_y, Start_y + Width_y).reshape((Width_y, 1))
    Data_Cut = Data_Ori[Vector_y, Vector_x]

    if 'Circle_Cut' in kw and kw['Circle_Cut'] == 1:
        Data_Cut = Im_Rot_Copy(Data_Cut, Width, Rot_Copy_Switch, Cut_Switch, Width_Min)

    if Mode == 4:
        Data_Cut[Data_Cut < Background] = 0
        return np.sum(Data_Cut), Data_Cut

    # Mode 0 and Mode 100 share the same background-subtracted sum.
    Diff_Data_Cut = Data_Cut - Background
    return np.sum(Diff_Data_Cut), Diff_Data_Cut


# %%
def _Read_Odimage_Hdf5(File_Name):
    """Read the 'odimage' dataset from an HDF5 (MAT v7.3) file and return it transposed."""
    # The context manager closes the file handle once the data has been copied.
    with h5py.File(File_Name, 'r') as f:
        return np.array(f['odimage']).T


def Save_And_Load_Data(Load_Name, Save_Name, File_Mode, **kw):
    """
    Load the 'odimage' array, using Save_Name.mat / Save_Name.npy as a cache.

    Save_Name is given without extension.
    - If neither Save_Name.mat nor Save_Name.npy exists, the image is read from
      Load_Name according to File_Mode and written to both cache files.
    - If Save_Name.mat exists it is loaded (and Save_Name.npy is created when
      missing; a PermissionError while writing it is ignored).
    - Otherwise Save_Name.npy is loaded (and Save_Name.mat is created).

    File_Mode 0, raw pic (not implemented); File_Mode 1, removed background
    with Matlab (.mat file); File_Mode 2, removed background with python
    (read with np.load and indexed with 'odimage').

    Keyword MatVersion='7.3' makes the .mat readers use h5py instead of
    scipy.io.loadmat.

    Returns the image array.  Raises ValueError for a File_Mode that has no
    loader (previously this surfaced as an UnboundLocalError).
    """
    Use_Hdf5 = 'MatVersion' in kw and kw['MatVersion'] == '7.3'
    Mat_Name = Save_Name + '.mat'
    Npy_Name = Save_Name + '.npy'
    Has_Mat = os.path.exists(Mat_Name)
    Has_Npy = os.path.exists(Npy_Name)

    if not Has_Mat and not Has_Npy:
        if File_Mode == 1:
            if Use_Hdf5:
                Od_Image = _Read_Odimage_Hdf5(Load_Name)
            else:
                Od_Image = sio.loadmat(Load_Name)['odimage']
        elif File_Mode == 2:
            # presumably Load_Name is an .npz archive (or structured array)
            # with an 'odimage' entry -- TODO confirm against the writer.
            Od_Image = np.load(Load_Name)['odimage']
        else:
            raise ValueError('Save_And_Load_Data: File_Mode %r has no loader' % (File_Mode,))
        print('Calculating ' + Load_Name)
        np.save(Npy_Name, Od_Image)
        sio.savemat(Mat_Name, {'odimage': Od_Image})
    elif Has_Mat:
        if Use_Hdf5:
            # NOTE(review): as in the original code, the v7.3 path reads
            # Load_Name rather than the cached Save_Name.mat -- confirm intent.
            Od_Image = _Read_Odimage_Hdf5(Load_Name)
        else:
            Od_Image = sio.loadmat(Mat_Name)['odimage']

        print('Calculating ' + Save_Name)
        if not Has_Npy:
            try:
                np.save(Npy_Name, Od_Image)
            except PermissionError:
                # Best effort: the .npy copy is only a cache.
                pass
    else:
        # The cache was written with np.save(Save_Name + '.npy', array), so it
        # must be read back with the extension and is already the bare array.
        Od_Image = np.load(Npy_Name)
        print('Calculating ' + Save_Name)
        sio.savemat(Mat_Name, {'odimage': Od_Image})
    return Od_Image


# def Draw_Atoms(Data, Photo_Ind, x_Range, y_Range, SaveName):
#     fig = plt.figure(Photo_Ind)
#     # plt.imshow(Data_Ori, vmin=-0.1, vmax=0.5, cmap='jet')
#     plt.imshow(Data_Ori[170:350, 150:330], vmin=-0.1, vmax=0.5, cmap='jet')
#     fig.savefig(SaveName+'.png')
#     fig.savefig(SaveName+'.pdf')
#     plt.pause(0.5)
#     plt.close()
# %%
def Load_Pic_And_Cal(Photo_Ind, Load_Name, File_Mode, Save_Name, SaveFig_Address, Atom_Center, Cut_Width, Cloud_Number,
                     OriImage_Switch, **kw):
    """
    Load one image, optionally plot it, and integrate every atom cloud in it.

    Photo_Ind: integer index of the picture; used in the plot label and the
        saved figure name.
    Load_Name, File_Mode, Save_Name: forwarded to Save_And_Load_Data.  The
        function waits (polling once per second) until Load_Name exists and
        can be loaded.
    SaveFig_Address: path prefix for the saved overview figure.
    Atom_Center: sequence of centres (x, y, ...); the LAST entry is the centre
        of the background window.
    Cut_Width: sequence of window widths (scalar or [width_x, width_y] list),
        parallel to Atom_Center; the LAST entry is the background window width.
    Cloud_Number: number of leading entries of Atom_Center to integrate.
    OriImage_Switch: when 1, the (cropped) image is shown and saved as
        SaveFig_Address + str(Photo_Ind) + 'Ori_Image.png'.

    Keyword options:
        Rot_Ang: angle passed to ndimage.rotate (reshape=False) before analysis.
        Cut_Aera: [x_Range, y_Range] crop used for the overview image.
        Enclose: per-cloud marker, 'Rectangle', 'Circle' or '0'.  '0' stops
            drawing markers; 'Circle' also makes Atom_N1 use Circle_Cut=1.
        Axis: 'off' hides the plot axes.

    Returns (Atom_Ratio, Atom_Regions, Region_All): the list of each cloud's
    share of the summed signal, the list of background-subtracted windows, and
    the cropped overview image.
    """
    # OriImage_Switch = 1
    # FitDraw_Switch = 0
    Wait_Sec = 1
    Atom_Mode = 0
    Find_Atom_Contour_Threshold = 0.05
    Background_Center = Atom_Center[-1]
    Background_Wid = Cut_Width[-1]
    Matrix_Shift = np.array([0, 0])

    # Poll until the picture exists and can be read.
    while True:
        if not os.path.exists(Load_Name):
            #            print('FileNotFoundError')
            time.sleep(Wait_Sec)
            continue
        try:
            Data_Ori = Save_And_Load_Data(Load_Name, Save_Name, File_Mode)
            break
        except FileNotFoundError:
            print('FileNotFoundError')
            time.sleep(Wait_Sec)
        except OSError:
            print('OSError')
            time.sleep(Wait_Sec)
        except ValueError:
            print('ValueError')
            time.sleep(Wait_Sec)
        except NotImplementedError:
            # The plain loader refused the file; retry through the MAT v7.3 path.
            Data_Ori = Save_And_Load_Data(Load_Name, Save_Name, File_Mode, MatVersion='7.3')
            break
        except Exception:
            # 'except Exception' (not a bare except) so that Ctrl-C and
            # SystemExit can still interrupt the retry loop.
            print('unknown error')
            time.sleep(Wait_Sec)
    #            return 'Error', Photo_Ind, sys.exc_info()

    if 'Rot_Ang' in kw:
        Data_Ori = ndimage.rotate(Data_Ori, kw['Rot_Ang'], reshape=False)

    # The image is indexed [y, x]: axis 0 is y (rows), axis 1 is x (columns).
    x_Range = [0, np.shape(Data_Ori)[1]]
    y_Range = [0, np.shape(Data_Ori)[0]]
    if 'Cut_Aera' in kw:
        x_Range = kw['Cut_Aera'][0]
        y_Range = kw['Cut_Aera'][1]
    y_Start = y_Range[0]
    y_End = y_Range[1]
    x_Start = x_Range[0]
    x_End = x_Range[1]
    Region_All = Data_Ori[y_Start:y_End, x_Start:x_End].copy()

    if OriImage_Switch == 1:
        fig, ax = plt.subplots(1)

        plt.imshow(Region_All, vmin=-0.1, vmax=0.5, cmap='jet')
        theta = np.linspace(-np.pi, np.pi, 1000)
        # The last entry is the background window, so it gets no marker.
        for Atom_Ind in range(len(Atom_Center) - 1):
            Shape = kw['Enclose'][Atom_Ind] if 'Enclose' in kw else None
            if Shape == '0':
                break
            if isinstance(Cut_Width[Atom_Ind], list):
                Wid_x, Wid_y = Cut_Width[Atom_Ind][0], Cut_Width[Atom_Ind][1]
            else:
                Wid_x = Wid_y = Cut_Width[Atom_Ind]
            # Marker coordinates are relative to the cropped overview image.
            Cen_x = Atom_Center[Atom_Ind][0] - x_Start
            Cen_y = Atom_Center[Atom_Ind][1] - y_Start
            if Shape == 'Rectangle':
                rect = patches.Rectangle((Cen_x - Wid_x / 2, Cen_y - Wid_y / 2),
                                         Wid_x, Wid_y, linewidth=1, edgecolor='r',
                                         facecolor='none')
                ax.add_patch(rect)
            elif Shape == 'Circle':
                circlex = Wid_x / 2 * np.cos(theta) + Cen_x
                circley = Wid_y / 2 * np.sin(theta) + Cen_y
                plt.plot(circlex, circley, 'k')
        plt.plot()
        plt.text(75, 45, 'photo%d' % Photo_Ind, color='w')
        if 'Axis' in kw and kw['Axis'] == 'off':
            plt.axis('off')
        fig.savefig(SaveFig_Address + str(Photo_Ind) + r'Ori_Image' + '.png')
        #        fig.savefig(SaveFig_Address+str(Photo_Ind)+r'Ori_Image'+'.pdf')
        plt.show(block=False)
        plt.pause(0.5)
        # plt.close()

    # Mean per-pixel background: window sum divided by the window area
    # (rectangular windows are given as a [width_x, width_y] list).
    if isinstance(Background_Wid, list):
        Background_Area = Background_Wid[0] * Background_Wid[1]
    else:
        Background_Area = Background_Wid * Background_Wid
    Background = Atom_N1(Data_Ori, Background_Center, Background_Wid, 0, 50,
                         Matrix_Shift, 0, 0, 0, 100, 0)[0] / Background_Area

    # Atom_Nums = np.zeros([18, 1])
    Atom_Nums = []
    Atom_Regions = []

    Atom_Ind = 0
    #    plt.close('all')
    #    plt.imshow(Data_Ori, vmin=-0.1, vmax=0.5, cmap='jet')
    while Atom_Ind < Cloud_Number:
        # Circular clouds are masked inside Atom_N1; evaluate each cloud once.
        if 'Enclose' in kw and kw['Enclose'][Atom_Ind] == 'Circle':
            Cut_Options = {'Circle_Cut': 1}
        else:
            Cut_Options = {}
        Atom_Num, Atom_Region = Atom_N1(Data_Ori, Atom_Center[Atom_Ind], Cut_Width[Atom_Ind],
                                        Background, Atom_Ind, Matrix_Shift, 0, 0, 1, Atom_Mode,
                                        Find_Atom_Contour_Threshold, **Cut_Options)
        # presumably 112.2 converts the summed signal to an atom number -- TODO confirm.
        Atom_Num *= 112.2
        Atom_Nums.append(Atom_Num)
        Atom_Regions.append(Atom_Region)
        Atom_Ind += 1

    # if FitBEC_Switch == 1:
    #     IsBEC_Threshold = np.array([0.015, 0.035])
    #     Od_Low = 0.3
    #     Size_8 = np.shape(Atom_Region_8)
    #     Atom_Size_8 = Cal_AtomSize(Atom_Region_8, [50, 0.17])
    #     Fi_8 = Is_BEC(Atom_Region_8, Atom_Size_8, IsBEC_Threshold)
    #     if Fi_8 == 1:
    #         Od_Lim_8 = 0.3
    #         Ini_Size_8 = 100
    #         Low_Size_8 = 50
    #         x1_8 = 3
    #         x7_8 = x1_8
    #         x1u_8 = 10
    #         x7u_8 = x1u_8
    #     elif Fi_8 == 0.5:
    #         Od_Lim_8 = np.max(Atom_Region_8)*0.5
    #         Od_Lim_8[Od_Lim_8 < Od_Low] = Od_Low
    #         Ini_Size_8 = Atom_Size_8
    #         Low_Size_8 = 10
    #         Ini_Size_8[Low_Size_8 > Ini_Size_8] = Low_Size_8+1
    #         x1_8 = Od_Lim_8/2
    #         x7_8 = x1_8
    #         x1u_8 = Od_Lim_8
    #         x7u_8 = x1u_8
    #     else:
    #         Od_Lim_8 = 5
    #         Ini_Size_8 = Atom_Size_8
    #         Low_Size_8 = 0
    #         x1_8 = Od_Lim_8/2
    #         x7_8 = x1_8
    #         x1u_8 = Od_Lim_8
    #         x7u_8 = x1u_8
    #     x_8 = np.array([x1_8, 0, Ini_Size_8, 0, Ini_Size_8, 0, x7_8, 0])
    #     x_Lim_8 = [[0, -Size_8[0]/2, Low_Size_8, -Size_8[1]/2, Low_Size_8, -np.pi/4, 0, -1], \
    #     [x1u_8, Size_8[0]/2, Size_8[1]*Size_8[1]/4, Size_8[0]/2, Size_8[1]*Size_8[1]/4, np.pi/4, x7u_8, 1]]

    Atom_Nums = np.array(Atom_Nums)
    # Each cloud's share of the total summed signal.
    Atom_Ratio = Atom_Nums / np.sum(Atom_Nums)
    Atom_Ratio = Atom_Ratio.tolist()
    return Atom_Ratio, Atom_Regions, Region_All


def Polylog(z, s, epsilon):
    """
    Approximate the polylogarithm Li_s(z) = sum_{n>=1} z**n / n**s.

    Terms are added in order and the summation stops at the first term whose
    magnitude does not exceed epsilon, so the result is a truncated series.

    z: scalar argument with |z| <= 1.
    s: scalar order.
    epsilon: positive cut-off for the size of the terms.

    Returns the partial sum (0 when the very first term is already within
    epsilon).  Raises ValueError for inputs whose terms can never fall below
    the cut-off, which would otherwise loop forever.
    """
    if epsilon <= 0:
        raise ValueError('Polylog: epsilon must be positive')
    if abs(z) > 1 or (abs(z) == 1 and s <= 0):
        raise ValueError('Polylog: series terms do not decay for z=%r, s=%r' % (z, s))

    g = 0
    n = 1
    while True:
        # float(n) avoids integer overflow and negative-integer-power errors.
        term = np.power(z, n) / np.power(float(n), s)
        if not abs(term) > epsilon:
            break
        g += term
        n += 1  # the original never advanced n, so the loop could not end
    return g
